{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset Generation\n",
    "\n",
    "This is a simple example of dataset generation using WebDataset `TarWriter`. Shard are uploaded to a server or to the cloud as they are generated.\n",
    "\n",
    "Parallel dataset generation with Ray is illustrated at the very end.\n",
    "\n",
    "This particular notebook generates short text samples using GPT-2. These can be used to generate OCR training data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "61u4BASSNq6y"
   },
   "outputs": [],
   "source": [
    "# package installs for colab\n",
    "\n",
    "import sys\n",
    "\n",
    "if \"google.colab\" in sys.modules:\n",
    "    !pip install --quiet webdataset\n",
    "    !pip install --quiet adapter-transformers\n",
    "    !pip install --quiet sentencepiece\n",
    "    !pip install --quiet datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!huggingface-cli whoami"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 243,
     "referenced_widgets": [
      "5011003a97364ffbaf0ba5c49f40a856",
      "8a24d38e49e845e68f949fd7e950c547",
      "11d4db32b02f45b89455ae3c7c1b48e7",
      "f68aa4ffbe37471ab728a5bfdd96a162",
      "a53f2b82f0c5473f887c87415441c851",
      "f6a5e728da38475fbbda774a43f529fb",
      "5e417069faed471392193db141b2851b",
      "4cb67b0c60bb4b1ba719897a2baac500",
      "1dc2b10916f14ee58c92be8159cf2e1e",
      "dcf507a023d24a29b56e871f7ddb4c3b",
      "c8a631d4c2bd4b2a91d54ee394c8fc71"
     ]
    },
    "id": "jrEQw7TXLqxC",
    "outputId": "8957bc2d-b768-4ad9-edba-d610f805b071"
   },
   "outputs": [],
   "source": [
    "import uuid\n",
    "import webdataset as wds\n",
    "import os\n",
    "\n",
    "from transformers import GPT2LMHeadModel, GPT2Tokenizer\n",
    "from transformers import pipeline\n",
    "import textwrap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters\n",
    "nsamples = 10\n",
    "ntokens = 100\n",
    "nshards = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# text generation with Huggingface and GPT2\n",
    "\n",
    "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\", padding_side=\"left\")\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "model = GPT2LMHeadModel.from_pretrained(\"gpt2\")\n",
    "generator = pipeline(\"text-generation\", model=model, tokenizer=tokenizer)\n",
    "\n",
    "\n",
    "def generate(n, prompt=\"\"):\n",
    "    \"\"\"Generate n words of text, starting with prompt.\"\"\"\n",
    "    global tokenizer, model, generator\n",
    "    output = generator(\n",
    "        prompt,\n",
    "        max_length=n + len(tokenizer.encode(prompt)),\n",
    "        do_sample=True,\n",
    "        temperature=0.99,\n",
    "        top_k=50,\n",
    "        top_p=0.99,\n",
    "        num_return_sequences=1,\n",
    "        pad_token_id=tokenizer.eos_token_id,\n",
    "    )[0]\n",
    "    return output[\"generated_text\"]\n",
    "\n",
    "\n",
    "text = generate(100).strip()\n",
    "print()\n",
    "print(textwrap.fill(text, 64))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ElUfC0IbaPnR"
   },
   "outputs": [],
   "source": [
    "# function generating an entire shard using TarWriter\n",
    "\n",
    "\n",
    "def generate_shard(oname, nsamples=10000, ntokens=500, prefix=\"\"):\n",
    "    \"\"\"Generate a shard of samples with text.\n",
    "\n",
    "    Each sample has a \"__key__\" field and a \"txt.gz\" field.\n",
    "    That is, the individual text files are compressed automatically on write.\n",
    "    They will be automatically decompressed when read.\n",
    "    \"\"\"\n",
    "    with wds.TarWriter(oname) as output:\n",
    "        for i in range(nsamples):\n",
    "            text = generate(100).strip()\n",
    "            key = uuid.uuid4().hex\n",
    "            text = generate(ntokens)\n",
    "            sample = {\"__key__\": key, \"txt.gz\": text}\n",
    "            output.write(sample)\n",
    "            if i % 10 == 0:\n",
    "                print(f\"{i:6d} {prefix}:\", repr(text)[:60])\n",
    "\n",
    "\n",
    "generate_shard(\"temp.tar\", nsamples=10, ntokens=10)\n",
    "!ls -l temp.tar\n",
    "!tar tf temp.tar | head -5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "u6I69B4FbPbk"
   },
   "outputs": [],
   "source": [
    "# We need a couple of simple functions to upload to the cloud.\n",
    "\n",
    "\n",
    "def cloud_exists(oname):\n",
    "    \"\"\"Check whether a file exists in the cloud.\"\"\"\n",
    "    # return os.system(f\"gsutil stat gs://mybucket/500tokens/{oname}\") == 0\n",
    "    return True\n",
    "\n",
    "\n",
    "def cloud_upload(oname):\n",
    "    \"\"\"Upload a file to the cloud.\"\"\"\n",
    "    # assert os.system(f\"gsutil cp {oname} gs://mybucket/500tokens/{oname}\") == 0\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We can now generate a shard and upload it to the cloud.\n",
    "# We skip the generation if the file already exists in the cloud.\n",
    "\n",
    "\n",
    "def generate_and_upload(i):\n",
    "    \"\"\"Generate a shard and upload it to the cloud.\"\"\"\n",
    "    oname = f\"text-{i:06d}.tar\"\n",
    "    if cloud_exists(oname):\n",
    "        print(f\"{oname} already exists, skipping\")\n",
    "        return False\n",
    "    generate_shard(oname, nsamples=nsamples, ntokens=ntokens, prefix=f\"{i:6d} {oname}\")\n",
    "    cloud_upload(oname)\n",
    "    os.remove(oname)\n",
    "    return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For sequential generation, use this\n",
    "\n",
    "for i in range(nshards):\n",
    "    generate_and_upload(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rjtefVDI-DfZ"
   },
   "outputs": [],
   "source": [
    "%%script true\n",
    "# For parallel generation, use this\n",
    "\n",
    "import ray\n",
    "\n",
    "@ray.remote(num_cpus=1, num_gpus=1)\n",
    "def ray_generate_and_upload(i):\n",
    "    \"\"\"A Ray remote function that generates a shard and uploads it to the cloud.\"\"\"\n",
    "    return generate_and_upload(i)\n",
    "\n",
    "def generate_shards(nshards=10):\n",
    "    \"\"\"Generate a number of shards and upload them to the cloud.\n",
    "    \n",
    "    Runs in parallel on a Ray cluster.\n",
    "    \"\"\"\n",
    "    ray.init(address='auto')  # Connect to the Ray cluster\n",
    "    tasks = [ray_generate_and_upload.remote(i) for i in range(nshards)]\n",
    "    ray.shutdown()\n",
    "    return shard_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "V100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "11d4db32b02f45b89455ae3c7c1b48e7": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_4cb67b0c60bb4b1ba719897a2baac500",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_1dc2b10916f14ee58c92be8159cf2e1e",
      "value": 0
     }
    },
    "1dc2b10916f14ee58c92be8159cf2e1e": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "4cb67b0c60bb4b1ba719897a2baac500": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": "20px"
     }
    },
    "5011003a97364ffbaf0ba5c49f40a856": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_8a24d38e49e845e68f949fd7e950c547",
       "IPY_MODEL_11d4db32b02f45b89455ae3c7c1b48e7",
       "IPY_MODEL_f68aa4ffbe37471ab728a5bfdd96a162"
      ],
      "layout": "IPY_MODEL_a53f2b82f0c5473f887c87415441c851"
     }
    },
    "5e417069faed471392193db141b2851b": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "8a24d38e49e845e68f949fd7e950c547": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_f6a5e728da38475fbbda774a43f529fb",
      "placeholder": "​",
      "style": "IPY_MODEL_5e417069faed471392193db141b2851b",
      "value": ""
     }
    },
    "a53f2b82f0c5473f887c87415441c851": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "c8a631d4c2bd4b2a91d54ee394c8fc71": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "dcf507a023d24a29b56e871f7ddb4c3b": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "f68aa4ffbe37471ab728a5bfdd96a162": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_dcf507a023d24a29b56e871f7ddb4c3b",
      "placeholder": "​",
      "style": "IPY_MODEL_c8a631d4c2bd4b2a91d54ee394c8fc71",
      "value": " 0/0 [00:00&lt;?, ?it/s]"
     }
    },
    "f6a5e728da38475fbbda774a43f529fb": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
